# Minimize the relaxation energy, then save the real-space relaxation displacements, the curl of the displacements, and the relaxed misfit energy to files.
# Note: outputs are written to a data/ folder in the directory containing this script (`@__DIR__/data/`), not the current working directory.

using Distributed
using TimerOutputs
using FFTW
using LinearAlgebra

using Optim
using LineSearches
using Printf

include("Realspace.jl")

# Discretization; configuration-space arrays below are sized N × N × N × N.
# Need to increase to 54 to reproduce the result of the manuscript.
N = 18

system = "triG"

# Twist angles in radians. The second layer is fixed as a reference.
θ1, θ2, θ3 = deg2rad(1.5), deg2rad(0.0), deg2rad(1.69)

# Echo the run parameters (angles back in degrees).
println("\n=============================================================\n")
foreach(println, (rad2deg(θ1), rad2deg(θ3), N))

# Timer shared by every @timeit call in this script.
global to = TimerOutput()

# FFTW planning wisdom is cached on disk so later runs can reuse plans. Import and export
# must use the same file: previously wisdom was imported from "FFTWwisdom2.jld" but
# exported to "FFTWwisdom.jld", so the exported wisdom was never picked up again.
wisdomfile = "FFTWwisdom.jld"
@timeit to "FFTW wisdom i/o" isfile(wisdomfile) && FFTW.import_wisdom(wisdomfile)
FFTW.set_num_threads(8)
BLAS.set_num_threads(8)

pcs = addprocs(4)
println("\n=============================================================\n        Using "
        * string(nworkers()) *  " workers... \n")

# Parameters are loaded on every worker; Trilayers.jl only on the main process.
@timeit to "GrapheneParameters.jl" @everywhere include("GrapheneParameters.jl")
@timeit to "Trilayers.jl" include("Trilayers.jl")

@timeit to "Trilayer setup" tlg = Trilayer(l*E0, l*P0, θ1, θ3, K, G)
@timeit to "Hull setup" hull = Hull(tlg, N)
# Save wisdom accumulated while building `hull` (presumably its FFT plans) for the next run.
@timeit to "FFTW wisdom i/o" FFTW.export_wisdom(wisdomfile)
hN = hull.hN;

# Optim callbacks that forward to the project's Energy / Gradient! / EnergyGradient!
# routines, using the global `hull`. Each call is timed under its own label in `to`.
# Only `fg!` is handed to the optimizer (via Optim.only_fg!). `f` and `g!` are
# standalone objective and gradient callbacks (`g!` presumably fills `storage`).
function f(u::Array{ComplexF64,6})
    return @timeit to "f(u)" Energy(u, hull)
end

function g!(storage::Array{ComplexF64,6}, u::Array{ComplexF64,6})
    return @timeit to "g!(storage, u)" Gradient!(storage, u, hull)
end

function fg!(F::Union{Nothing, Float64}, G::Union{Nothing, Array{ComplexF64,6}},
             u::Array{ComplexF64,6})
    return @timeit to "fg!(F,G,u)" EnergyGradient!(F, G, u, hull)
end

# Initial guess for the displacement field, in the array layout the optimizer works on.
# `init_noise` is the amplitude of an optional random perturbation. With the default 0.0
# the guess is exactly zero. Unlike the old `zeros(...) + 0. * rand(...)`, this does not
# allocate a full-size random array only to multiply it by zero.
init_noise = 0.0
guess_dims = (2, hN, N, N, N, 3)
@timeit to "Initialization" guess = (iszero(init_noise) ?
                                     zeros(ComplexF64, guess_dims) :
                                     init_noise .* rand(ComplexF64, guess_dims))

# Run optimization algorithm: L-BFGS with the hull's elastic preconditioner, driven by
# the fused objective/gradient callback `fg!`.
@timeit to "function setup" d = Optim.only_fg!(fg!)
@timeit to "method" method = LBFGS(; m = 5, P = hull.Precon_Elastic, scaleinvH0 = false)
@timeit to "options" options = Optim.Options(
            iterations = 10000, x_tol = 1e-4, f_tol = 1e-11, g_tol = 1e-8, allow_f_increases = true,
            show_trace = true, store_trace = false, show_every = 10 )
@timeit to "optimize" results = optimize(d, guess, method, options)
display(results)
display(to)
# The results are saved regardless, so make non-convergence visible instead of silent.
Optim.converged(results) ||
    @warn "Optimization did not converge; saved relaxation may be inaccurate." iterations = Optim.iterations(results)

# Build a transformed copy U of the relaxed solution. `u` is mapped with hull.iplan,
# layers 1 and 3 are swapped (layer 2 maps onto itself), and the four inner axes are
# permuted. Component 2 is negated, component 1 kept. Both U and u are then mapped back
# with hull.plan. NOTE(review): iplan/plan are presumably inverse/forward FFT plans and
# U presumably the mirror-image solution — confirm.
begin
    u = copy(results.minimizer)
    ur = hull.iplan * u
    Ur = similar(ur)
    # Each entry: (target layer, source layer, axis permutation).
    for comp in 1:2, (dst, src, perm) in ((1, 3, [2, 1, 4, 3]),
                                          (3, 1, [2, 1, 4, 3]),
                                          (2, 2, [4, 3, 2, 1]))
        moved = permutedims(ur[comp, :, :, :, :, src], perm)
        Ur[comp, :, :, :, :, dst] = comp == 1 ? moved : -moved
    end
    U = hull.plan * Ur
    u = hull.plan * ur
end


# Displacement gradients of the relaxed solution u and of its transformed copy U.
∇ur = ∇(u, hull)
∇Ur = ∇(U, hull)   # NOTE(review): ∇Ur is not used anywhere below in this script

# Output folder next to this script. Create it rather than failing when it is missing.
outpath = string("$(@__DIR__)", "/data/")
mkpath(outpath)

@timeit to "GrapheneParameters.jl" @everywhere include("GrapheneParameters.jl")

# All three outputs share the "<θ1 deg>_<θ3 deg>_<N>.jld" suffix.
tag = @sprintf("%.2f_%.2f_%d.jld", rad2deg(θ1), rad2deg(θ3), N)
datafile(kind) = string(outpath, system, "_", kind, "_", tag)

write(datafile("data"), N, θ1, θ3, tlg.E, tlg.P, K, G)
write(datafile("minimizer"), u, results.minimum)
write(datafile("gradient"), ∇ur[1,1], ∇ur[1,2], ∇ur[2,1], ∇ur[2,2])

rmprocs(pcs)
nothing


# Convert from configuration space to real space and save the real-space displacement.
# From this point on θ1 and θ3 hold angles in degrees; they appear in output file names.
θ1, θ3 = rad2deg(θ1), rad2deg(θ3)
scale = 20

# Reload the relaxed solution saved above.
tlg, u, ∇u = Read(N, θ1, θ3, "/data/", system)

itp_u = Fields([u[1, :, :, :, :, :], u[2, :, :, :, :, :]], N)
itp_∇u = Fields(∇u, N)

# Window size: use layer 1 or layer 3, whichever has the smaller |θ|, combined with
# reference layer 2 via inv(inv(tE[k]) - inv(tE[2])).
E = if abs(tlg.θ[1]) < abs(tlg.θ[3])
    inv(inv(tlg.tE[1]) - inv(tlg.tE[2]))
else
    inv(inv(tlg.tE[3]) - inv(tlg.tE[2]))
end

Lx = scale * maximum(abs, E[1, :])
Ly = scale * maximum(abs, E[2, :])
ω = Configuration(zeros(2), zeros(2), zeros(2), tlg)

# Real-space (x, y) sample points stored as the columns of a 2 × (nx*ny) matrix,
# with x varying fastest.
n = 40
nx, ny = 4 * n, 5 * n
X = range(-2 * Lx, 2 * Lx, length=nx)
Y = range(-Ly / 2, 5 * Ly / 2, length=ny)
Grid = permutedims(hcat(repeat(X, outer=ny), repeat(Y, inner=nx)))
@time U = interpolateFields(Grid, itp_u, ω, N, tlg)

# Flattened x and y coordinates, in the same column-major order as `Grid` (x fastest).
xarr = vec(X .* ones(1, 5 * n))
yarr = vec(ones(4 * n, 1) .* reshape(Y, (1, 5 * n)))

# One comma-separated row per grid point; columns follow the header line.
disp_path = string(outpath, system, "_q12_", round(θ1; digits=2), "deg_q23_",
                   round(θ3; digits=2), "deg_N_", N, "_n_", n, "_scale_", scale, "_disp.txt")
open(disp_path, "w") do io
    write(io, "x,y,u1x,u1y,u2x,u2y,u3x,u3y \n")
    for (i, (rx, ry)) in enumerate(zip(xarr, yarr))
        row = (rx, ry, U[1, 1][i], U[2, 1][i], U[1, 2][i], U[2, 2][i], U[1, 3][i], U[2, 3][i])
        write(io, join(row, ", "), "\n")
    end
end



# Relaxed misfit energy of the layer-1/2 and layer-2/3 interfaces on the configuration
# grid, wrapped as interpolable fields (`itp_misfit`).
hull = Hull(tlg, N)
Γ0 = Array(hull.Γ0)
permutations = Array(hull.permutations)
invE = hull.tl.invtE

# unrelaxed gsfe
gsfe0 = GSFE(Γ0)

u_tmp = reshape(u, (2, N^4, 3))

"""
    pair_misfit(uj, u2, permutation, invEs, Γ0)

Return the relaxed misfit energy of one pair of adjacent layers.

For each matrix `invEi` in `invEs`:

1. Shift the reference configurations `Γ0` by `invEi * (u2[:, permutation] - uj)`.
2. Accumulate `0.5 * GSFE(shifts)`.

`GSFE` is defined outside this script; it is presumably the generalized stacking-fault energy.
"""
function pair_misfit(uj, u2, permutation, invEs, Γ0)
    npts = size(Γ0, 2)
    shifts = zeros(Float64, (2, npts))
    acc = zeros(Float64, npts)
    for invEi in invEs
        shifts[1,:] = Γ0[1,:] + invEi[1,1]*(u2[1, permutation] - uj[1,:]) +
                                invEi[1,2]*(u2[2, permutation] - uj[2,:])
        shifts[2,:] = Γ0[2,:] + invEi[2,1]*(u2[1, permutation] - uj[1,:]) +
                                invEi[2,2]*(u2[2, permutation] - uj[2,:])
        acc += 0.5*GSFE(shifts)
    end
    return acc
end

# relaxed energy
# L1 + L2: uses invE[1] and invE[2]
misfit12 = pair_misfit(u_tmp[:,:,1], u_tmp[:,:,2], permutations[:,1], (invE[1], invE[2]), Γ0)
# L2 + L3: uses invE[2] and invE[3]
misfit23 = pair_misfit(u_tmp[:,:,2], u_tmp[:,:,3], permutations[:,2], (invE[2], invE[3]), Γ0)

# Pack into the third-index slots 1 (L1+L2) and 3 (L2+L3); slot 2 stays zero.
misfit = zeros((N,N,N,N,3))
misfit[:,:,:,:,1] = reshape(misfit12, (N,N,N,N))
misfit[:,:,:,:,3] = reshape(misfit23, (N,N,N,N))
itp_misfit = Fields([misfit], N)

# Sample the misfit-energy fields on a finer real-space grid (n = 500) and save them.
# Reuses ω and Lx/Ly (already multiplied by `scale`) from the displacement section.
# u, ∇u, tlg and scale are unchanged since then, so the interpolants and window size are
# not rebuilt. The displacement is not re-interpolated on this grid: that result was unused.
n = 500
X = range(-2*Lx, 2*Lx, length=4*n)
Y = range(-Ly/2, 5*Ly/2, length=5*n)
Grid = cat( reshape(X, (1,4*n,1)) .* ones(1,1,5*n),
            ones(1,4*n,1) .* reshape(Y, (1,1,5*n)), dims=1)
Grid = reshape(Grid, (2, 20*n^2))
@time E_interp = interpolateFields(Grid, itp_misfit, ω, N, tlg)

# Flattened coordinates and energies, in the column-major order of `Grid` (x fastest).
xarr = vec(X .* ones(1, 5*n))
yarr = vec(ones(4*n, 1) .* reshape(Y, (1, 5*n)))
E_interp12 = vec(E_interp[1])
E_interp23 = vec(E_interp[3])

# save the misfit energy landscape to file
open(string(outpath, system, "_q12_", round(θ1; digits=2), "deg_q23_",
    round(θ3; digits=2), "deg_N_", N, "_scale_", scale, "_disp_energy.txt"), "w") do io
    write(io, "x,y,misfit12,misfit23 \n")
    # eachindex over all four vectors also checks that their lengths match.
    for i in eachindex(xarr, yarr, E_interp12, E_interp23)
        write(io, "$(xarr[i]), $(yarr[i]), $(E_interp12[i]), $(E_interp23[i])\n")
    end
end


# Interpolate the displacement gradients onto the current real-space grid `Grid`. Then
# save the grid axes and, for each layer, the curl of the displacement field.
@time ∇U = interpolateFields(Grid, itp_∇u, ω, N, tlg)

# File-name prefix shared by the three outputs of this section.
prefix = string(outpath, system, "_q12_", round(θ1; digits=2), "deg_q23_",
                round(θ3; digits=2), "deg_N_", N, "_scale_", scale)

# Grid axes, one value per line.
open(string(prefix, "_xarr.txt"), "w") do io
    for x in X
        write(io, "$x\n")
    end
end

open(string(prefix, "_yarr.txt"), "w") do io
    for y in Y
        write(io, "$y\n")
    end
end

# Curl for layers 1..3, built from the off-diagonal gradient entries: ∇U[2,1,l] - ∇U[1,2,l].
# NOTE(review): whether this is ∂x u_y - ∂y u_x or its negative depends on the index
# convention of ∇ / interpolateFields, which is defined outside this file — confirm.
curl1 = ∇U[2,1,1] .- ∇U[1,2,1]
curl2 = ∇U[2,1,2] .- ∇U[1,2,2]
curl3 = ∇U[2,1,3] .- ∇U[1,2,3]
open(string(prefix, "_curl.txt"), "w") do io
    # Iterate over the arrays' own indices (requiring equal shapes) instead of
    # assuming they hold exactly length(X)*length(Y) values.
    for i in eachindex(curl1, curl2, curl3)
        write(io, "$(curl1[i]), $(curl2[i]), $(curl3[i])\n")
    end
end
